from geartrain.backends.base import Backend
import tensorflow as tf
tf = tf.compat.v1
tf.disable_v2_behavior()

class TF_Backend(Backend):
    """TensorFlow 1.x (compat.v1) backend that serves a frozen GraphDef (.pb).

    The graph is loaded once by ``_load_model``; every ``_forward`` call
    opens a short-lived ``tf.Session``, feeds the configured input tensors,
    and fetches the configured output tensors.
    """

    def __init__(self, device, model_path, input_name: list = None,
                 output_name: list = None, **kwargs) -> None:
        """
        Args:
            device: GPU index, or -1 to run on CPU.
            model_path: path to the frozen .pb GraphDef file.
            input_name: graph input tensor names, e.g. ["input:0"].
            output_name: graph output tensor names.
            **kwargs: optional settings; recognizes
                gpu_memory_fraction (float, default 0.6) applied to
                ConfigProto.gpu_options.per_process_gpu_memory_fraction.
        """
        self.model_path = model_path
        # NOTE(review): self._device is computed but never applied to the
        # session or graph below — confirm whether device pinning (e.g.
        # tf.device) was intended here.
        self._device = "/CPU:0" if device == -1 else f"/device:GPU:{device}"
        self.input_name = input_name
        self.output_name = output_name
        # Previously hard-coded 0.6 in _load_model; generalized to a keyword
        # option with the same default, so existing callers are unaffected.
        self._gpu_memory_fraction = kwargs.get("gpu_memory_fraction", 0.6)
        super().__init__(device, "tf")

    def _load_model(self) -> None:
        """Load the frozen graph and resolve the configured I/O tensors.

        Raises:
            ValueError: if input_name/output_name were not provided or are
                empty — previously this surfaced as an opaque
                "'NoneType' object is not iterable" TypeError.
        """
        if not self.input_name or not self.output_name:
            raise ValueError(
                "TF_Backend requires non-empty input_name and output_name lists"
            )

        self.model = self._load_pb_model(self.model_path)

        self.config = tf.ConfigProto()
        self.config.gpu_options.per_process_gpu_memory_fraction = (
            self._gpu_memory_fraction
        )

        self.input_tensors = [
            self.model.get_tensor_by_name(name) for name in self.input_name
        ]
        self.output_tensors = [
            self.model.get_tensor_by_name(name) for name in self.output_name
        ]
        # A single output is unwrapped so sess.run returns the bare result
        # instead of a one-element list.
        if len(self.output_tensors) == 1:
            self.output_tensors = self.output_tensors[0]

    def _load_pb_model(self, model_path):
        """Parse a frozen .pb file into a fresh tf.Graph and return it."""
        graph = tf.Graph()
        with graph.as_default():
            with tf.gfile.GFile(model_path, "rb") as f:
                graph_def = tf.GraphDef()
                graph_def.ParseFromString(f.read())
                # name="" keeps the original tensor names (no "import/"
                # prefix) so input_name/output_name resolve unchanged.
                tf.import_graph_def(graph_def, name="")

        return graph

    def _forward(self, model_inputs):
        """Run one inference pass.

        Args:
            model_inputs: a single array (fed to the first input tensor),
                a list of arrays (one per input tensor, in order), or a
                dict with key "model_inputs" wrapping either of the above.

        Returns:
            The fetched output(s): a single value when one output tensor is
            configured, otherwise a list of values.

        Raises:
            ValueError: if a list input's length does not match the number
                of input tensors (previously a bare IndexError on too few
                inputs, and extras were silently ignored).
        """
        if isinstance(model_inputs, dict):
            model_inputs = model_inputs["model_inputs"]

        if isinstance(model_inputs, list):
            # zip would silently truncate a mismatch; check explicitly so
            # caller errors fail loudly with a useful message.
            if len(model_inputs) != len(self.input_tensors):
                raise ValueError(
                    f"expected {len(self.input_tensors)} model inputs, "
                    f"got {len(model_inputs)}"
                )
            feed_dict = dict(zip(self.input_tensors, model_inputs))
        else:
            feed_dict = {self.input_tensors[0]: model_inputs}

        # A fresh session per call keeps the backend leak-free at the cost
        # of per-call session startup.
        with tf.Session(graph=self.model, config=self.config) as sess:
            outputs = sess.run(self.output_tensors, feed_dict=feed_dict)

        return outputs